import json
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Apply seaborn's "whitegrid" theme. This runs at module import time, so it
# is in effect before any figure in this file is created.
sns.set_theme(style="whitegrid")


# Display order of the category subfields. A subfield's index in this tuple
# is its sort rank: lower ranks are placed further left on the x-axis.
_subfield_display_order = (
    "Ideology",
    "Personality",
    "Outlook",
    "Communal",
    "Occupation+Personality",
    "Occupation+Ideology",
    "Occupation+Outlook",
    "Occupation+Communal",
    "Occupation",
)

# Maps subfield name -> sort rank; used as the sort key when ordering bars.
sort_order = {name: rank for rank, name in enumerate(_subfield_display_order)}

# Colors for each model, taken from seaborn's "muted" palette. The palette is
# built once and indexed, instead of being rebuilt for every entry.
_muted_palette = sns.color_palette("muted")
model_colors = {
    "GPT-3.5": _muted_palette[0],  # Blue
    "GPT-4o": _muted_palette[2],  # Green
    "Llama-3": _muted_palette[3],  # Red
    "Claude-3.5-Sonnet": _muted_palette[4],  # Purple
}

# Width of a single bar in x-axis data units. It is also used as the
# horizontal offset between consecutive models' bars within one cluster.
bar_width = 0.2


def tanh_log(v):
    """Return ``tanh(ln(v))`` for a positive number ``v``.

    Inputs below 1 give negative results, 1 gives 0, and inputs above 1 give
    positive results; the result always lies between -1 and 1.

    Args:
        v: The value to transform.

    Returns:
        The hyperbolic tangent of the natural logarithm of ``v``.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    # Keep the comparison as "<= 0" (not an inverted "> 0") so that NaN,
    # which compares False either way, is passed through to math.log.
    is_non_positive = v <= 0
    if is_non_positive:
        message = "Input value must be greater than 0 since log(v) is undefined for non-positive values."
        raise ValueError(message)

    natural_log = math.log(v)
    return math.tanh(natural_log)


def get_results_data(filepath):
    """Load and return the parsed JSON document stored at ``filepath``.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; naming the encoding keeps the result independent of
    # the platform's default text encoding.
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)



# Create a dict with model and bar heights
def create_model_bar_heights(category, data):
    """Compute the bar heights of every model for each subfield of ``category``.

    Args:
        category: Key selecting what to plot; read as ``data[model][category]``.
        data: Mapping of model name -> category -> {subfield: score}. A score
            of ``None`` is treated as missing.

    Returns:
        A ``(model_bar_heights_dict, all_category_subfields)`` tuple.
        ``all_category_subfields`` lists every subfield found under
        ``category`` for any model, ordered by ``sort_order`` rank (subfields
        without a rank come last). ``model_bar_heights_dict`` maps each model
        to a list of heights aligned with that subfield list: the
        ``tanh_log`` of the score, or 0 where the model has no score.

    Raises:
        KeyError: If a model has no entry for ``category``.
        ValueError: Propagated from ``tanh_log`` for a non-positive score.
    """
    models = list(data)

    # Gather subfields in first-seen order. dict keys de-duplicate while
    # preserving insertion order, whereas a set of strings iterates in an
    # order that can change from one run to the next.
    first_seen_subfields = dict.fromkeys(
        subfield for model in models for subfield in data[model][category]
    )
    # sorted() is stable, so subfields that share a rank (including all
    # unranked ones) keep their first-seen order and the layout is
    # reproducible.
    all_category_subfields = sorted(
        first_seen_subfields,
        key=lambda subfield: sort_order.get(subfield, float("inf")),
    )

    model_bar_heights_dict = {}
    for model in models:
        transformed_scores = {
            subfield: tanh_log(score)
            for subfield, score in data[model][category].items()
            if score is not None
        }
        # A subfield the model lacks (or whose score is None) gets height 0.
        model_bar_heights_dict[model] = [
            transformed_scores.get(subfield, 0)
            for subfield in all_category_subfields
        ]

    return model_bar_heights_dict, all_category_subfields


def create_bar_plot(model_bar_heights_dict, ax, positions):
    """Draw one series of bars per model onto ``ax``.

    Args:
        model_bar_heights_dict: Mapping of model name -> list of bar heights.
            Every model name must be a key of ``model_colors``.
        ax: Axes object whose ``bar`` method is used for drawing.
        positions: x positions of the first model's bars; must support adding
            a scalar offset (``plot_data`` passes a NumPy array).

    Returns:
        The same ``ax`` that was passed in.
    """
    series = model_bar_heights_dict.items()
    for series_index, (model_name, bar_heights) in enumerate(series):
        # Shift each successive model right by one bar width so that the bars
        # of a cluster sit side by side.
        shifted_positions = positions + series_index * bar_width
        bar_style = {
            "label": model_name,
            "color": model_colors[model_name],
            "edgecolor": "black",  # Outline every bar
        }
        ax.bar(shifted_positions, bar_heights, bar_width, **bar_style)

    return ax


def plot_data(data, category, *, cluster_gap=0.5):
    """Show a grouped bar chart of the per-model scores for ``category``.

    One cluster of bars is drawn per subfield, with one bar per model inside
    each cluster. The figure is displayed with ``plt.show()``.

    Args:
        data: Results mapping, passed through to ``create_model_bar_heights``.
        category: Category whose subfields are plotted.
        cluster_gap: Extra horizontal space, in x-axis data units, added
            between consecutive clusters. Larger values spread clusters apart.

    Returns:
        None.
    """
    model_bar_heights_dict, all_category_subfields = create_model_bar_heights(
        category, data
    )
    num_models = len(model_bar_heights_dict)
    num_subfields = len(all_category_subfields)

    # x position of the first model's bar in each cluster. Consecutive
    # clusters are spaced by the combined width of one cluster's bars plus
    # cluster_gap.
    cluster_stride = bar_width * num_models + cluster_gap
    positions = np.arange(num_subfields) * cluster_stride

    # Set up the figure
    fig, ax = plt.subplots(figsize=(9, 6))
    ax = create_bar_plot(model_bar_heights_dict, ax, positions)

    # Customize the plot
    ax.set_ylabel("Bias score", fontsize=18, fontweight="bold")
    ax.set_ylim(-1, 1)
    # Put each tick midway between the first and last bar position of its
    # cluster.
    ax.set_xticks(positions + (num_models - 1) * bar_width / 2)
    ax.set_xticklabels(all_category_subfields, rotation=30, ha="right", fontsize=18)
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, 1.2),
        fontsize=16,
        # One legend column per model; never request zero columns when the
        # data contains no models.
        ncol=max(num_models, 1),
        frameon=True,
    )

    # Remove vertical gridlines and keep horizontal gridlines
    ax.grid(axis="y", linestyle="--", linewidth=0.7)
    ax.grid(axis="x", visible=False)

    # Emphasize the zero baseline (y=0) with a thick cyan line
    ax.axhline(y=0, color="cyan", linewidth=2)

    # Tight layout for better spacing
    plt.tight_layout()

    # Show the plot
    plt.show()

if __name__ == "__main__":
    # Script entry point: read the results file and plot its "Gender" category.
    plot_data(get_results_data("template_negatives.json"), "Gender")
